// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause

package eventbustest

import (
	"errors"
	"fmt"
	"reflect"
	"sync"
	"testing"
	"time"

	"github.com/google/go-cmp/cmp"
	"tailscale.com/util/eventbus"
)

// NewBus returns a fresh [eventbus.Bus] whose lifetime is tied to t: the bus
// is closed automatically during t's cleanup.
func NewBus(t testing.TB) *eventbus.Bus {
	b := eventbus.New()
	t.Cleanup(func() { b.Close() })
	return b
}

// NewWatcher returns a [Watcher] that records the events flowing through bus,
// so that the behavior of code under test can be checked with [Expect] and
// [ExpectExactly]. The watcher is told to stop automatically when t's test
// ends.
func NewWatcher(t *testing.T, bus *eventbus.Bus) *Watcher {
	w := new(Watcher)
	w.mon = bus.Debugger().WatchBus()
	w.chDone = make(chan bool, 1)
	w.events = make(chan any, 100)
	t.Cleanup(w.done)
	go w.watch()
	return w
}

// Watcher monitors and holds events for test expectations.
// The Watcher works with [testing/synctest], and some scenarios do require the
// use of [testing/synctest]; among others, this is the case when you are
// testing for the absence of events.
//
// For usage examples, see the documentation at the top of the package.
type Watcher struct {
	// mon is the debugger subscription (from WatchBus) whose events are
	// forwarded onto events by the watch goroutine.
	mon *eventbus.Subscriber[eventbus.RoutedEvent]

	// events buffers the payloads of observed events for Expect and
	// ExpectExactly to consume.
	events chan any

	// chDone is closed to signal that the watcher has stopped.
	chDone chan bool
}

// Type returns an event-matching function for use with [Expect] and
// [ExpectExactly] that matches any event of type T, regardless of the event's
// content. For example:
//
//	eventbustest.Expect(tw, eventbustest.Type[EventFoo]())
func Type[T any]() func(T) { return func(T) {} }

// Expect verifies that the given events are a subsequence of the events
// observed by tw. That is, tw must contain at least one event matching the type
// of each argument in the given order; other event types are allowed to occur
// in between without error. The given events are represented by a function
// that must have one of the following forms:
//
//	// Tests for the event type only
//	func(e ExpectedType)
//
//	// Tests for event type and whatever is defined in the body.
//	// If return is false, the test will look for other events of that type
//	// If return is true, the test will look for the next given event
//	// if a list is given
//	func(e ExpectedType) bool
//
//	// Tests for event type and whatever is defined in the body.
//	// The boolean return works as above.
//	// If error != nil, the test helper will return that error immediately.
//	func(e ExpectedType) (bool, error)
//
//	// Tests for event type and whatever is defined in the body.
//	// If a non-nil error is reported, the test helper will return that error
//	// immediately; otherwise the expectation is considered to be met.
//	func(e ExpectedType) error
//
// All filters are validated before any event is consumed; a malformed filter
// causes a panic. Events that tw has already buffered are still matched after
// the watcher has stopped; only once they are exhausted does Expect report
// that the watcher closed.
//
// If the list of events must match exactly with no extra events,
// use [ExpectExactly].
func Expect(tw *Watcher, filters ...any) error {
	if len(filters) == 0 {
		return errors.New("no event filters were provided")
	}
	// Build every matcher once: this surfaces malformed filters before any
	// events are consumed, and avoids rebuilding the current matcher (via
	// reflection) for every non-matching event.
	matchers := make([]func(any) (bool, error), 0, len(filters))
	for _, f := range filters {
		matchers = append(matchers, eventFilter(f))
	}

	eventCount := 0
	head := 0
	for head < len(matchers) {
		var event any
		select {
		case event = <-tw.events:
		// Use synctest when you want an error here.
		case <-time.After(100 * time.Second): // "indefinitely", to advance a synctest clock
			return fmt.Errorf(
				"timed out waiting for event, saw %d events, %d was expected",
				eventCount, len(filters))
		case <-tw.chDone:
			// select picks randomly among ready cases, so events may still be
			// buffered even though the watcher has stopped. Consume those
			// before giving up, so the outcome does not depend on chance.
			select {
			case event = <-tw.events:
			default:
				return errors.New("watcher closed while waiting for events")
			}
		}
		eventCount++
		ok, err := matchers[head](event)
		if err != nil {
			return err
		}
		if ok {
			head++
		}
	}
	return nil
}

// ExpectExactly checks for some number of events showing up on the event bus
// in a given order, returning an error if the events do not match the given
// list exactly. The given events are represented by a function as described in
// [Expect]. Use [Expect] if other events are allowed.
//
// Each expected event must be the next event observed, its dynamic type must
// be identical to the filter's parameter type, and the filter must report a
// match. Exactly one event is consumed per filter; events that arrive after
// the last expected one are not checked. With no filters, ExpectExactly
// reports an error if any event is observed before the wait times out or the
// watcher stops.
//
// Events that tw has already buffered are still checked after the watcher has
// stopped; only once they are exhausted does ExpectExactly report that the
// watcher closed.
//
// If you are expecting ExpectExactly to fail because of a missing event, or if
// you are testing for the absence of events, call [testing/synctest.Wait]
// after actions that would publish an event, but before calling ExpectExactly.
func ExpectExactly(tw *Watcher, filters ...any) error {
	if len(filters) == 0 {
		select {
		case event := <-tw.events:
			return fmt.Errorf("saw event type %s, expected none", reflect.TypeOf(event))
		case <-time.After(100 * time.Second): // "indefinitely", to advance a synctest clock
			return nil
		case <-tw.chDone:
			// A stopped watcher delivers no new events, so only what is
			// already buffered matters; don't wait out the full timeout.
			select {
			case event := <-tw.events:
				return fmt.Errorf("saw event type %s, expected none", reflect.TypeOf(event))
			default:
				return nil
			}
		}
	}

	eventCount := 0
	for pos, next := range filters {
		eventFunc := eventFilter(next) // panics on a malformed filter
		argType := reflect.TypeOf(next).In(0)

		var event any
		select {
		case event = <-tw.events:
		case <-time.After(100 * time.Second): // "indefinitely", to advance a synctest clock
			return fmt.Errorf(
				"timed out waiting for event, saw %d events, %d was expected",
				eventCount, len(filters))
		case <-tw.chDone:
			// select picks randomly among ready cases; drain buffered events
			// before reporting that the watcher closed.
			select {
			case event = <-tw.events:
			default:
				return errors.New("watcher closed while waiting for events")
			}
		}
		eventCount++

		if typeEvent := reflect.TypeOf(event); typeEvent != argType {
			return fmt.Errorf(
				"expected event type %s, saw %s, at index %d",
				argType, typeEvent, pos)
		}
		ok, err := eventFunc(event)
		if err != nil {
			return err
		}
		if !ok {
			return fmt.Errorf(
				"expected test ok for type %s, at index %d", argType, pos)
		}
	}
	return nil
}

// watch forwards the payload of every event observed by tw.mon onto tw.events.
// It runs until the monitor reports done, in which case it stops the watcher
// so that pending Expect calls wake up. It also stops when the watcher is told
// to stop via done, in which case it closes the monitor.
func (tw *Watcher) watch() {
	for {
		select {
		case event := <-tw.mon.Events():
			// tw.events is bounded; if nobody is draining it when the watcher
			// is stopped, don't block here forever (leaking this goroutine
			// and never closing the monitor).
			select {
			case tw.events <- event.Event:
			case <-tw.chDone:
				tw.mon.Close()
				return
			}
		case <-tw.mon.Done():
			tw.done()
			return
		case <-tw.chDone:
			tw.mon.Close()
			return
		}
	}
}

// watcherDoneMu serializes the check-and-close in [Watcher.done]. Without it,
// the test cleanup and the watch goroutine could both attempt to close the
// same channel.
var watcherDoneMu sync.Mutex

// done tells the watcher to stop monitoring for new events. It may be called
// more than once and from multiple goroutines (it is both registered as a test
// cleanup and invoked by watch when the monitor finishes); only the first call
// closes tw.chDone, and later calls are no-ops instead of panicking.
func (tw *Watcher) done() {
	watcherDoneMu.Lock()
	defer watcherDoneMu.Unlock()
	select {
	case <-tw.chDone:
		// Already closed.
	default:
		close(tw.chDone)
	}
}

type filter = func(any) (bool, error)

func eventFilter(f any) filter {
	ft := reflect.TypeOf(f)
	if ft.Kind() != reflect.Func {
		panic("filter is not a function")
	} else if ft.NumIn() != 1 {
		panic(fmt.Sprintf("function takes %d arguments, want 1", ft.NumIn()))
	}
	var fixup func([]reflect.Value) []reflect.Value
	switch ft.NumOut() {
	case 0:
		fixup = func([]reflect.Value) []reflect.Value {
			return []reflect.Value{reflect.ValueOf(true), reflect.Zero(reflect.TypeFor[error]())}
		}
	case 1:
		switch ft.Out(0) {
		case reflect.TypeFor[bool]():
			fixup = func(vals []reflect.Value) []reflect.Value {
				return append(vals, reflect.Zero(reflect.TypeFor[error]()))
			}
		case reflect.TypeFor[error]():
			fixup = func(vals []reflect.Value) []reflect.Value {
				pass := vals[0].IsZero()
				return append([]reflect.Value{reflect.ValueOf(pass)}, vals...)
			}
		default:
			panic(fmt.Sprintf("result is %v, want bool or error", ft.Out(0)))
		}
	case 2:
		if ft.Out(0) != reflect.TypeFor[bool]() || ft.Out(1) != reflect.TypeFor[error]() {
			panic(fmt.Sprintf("results are %v, %v; want bool, error", ft.Out(0), ft.Out(1)))
		}
		fixup = func(vals []reflect.Value) []reflect.Value { return vals }
	default:
		panic(fmt.Sprintf("function returns %d values", ft.NumOut()))
	}
	fv := reflect.ValueOf(f)
	return reflect.MakeFunc(reflect.TypeFor[filter](), func(args []reflect.Value) []reflect.Value {
		if !args[0].IsValid() || args[0].Elem().Type() != ft.In(0) {
			return []reflect.Value{reflect.ValueOf(false), reflect.Zero(reflect.TypeFor[error]())}
		}
		return fixup(fv.Call([]reflect.Value{args[0].Elem()}))
	}).Interface().(filter)
}

// Injector publishes test events onto an [eventbus.Bus] through a single
// [eventbus.Client], keeping one [eventbus.Publisher] per event type.
type Injector struct {
	client *eventbus.Client

	// publishers maps an event type T to its *eventbus.Publisher[T]. Values
	// are stored as any because the map holds publishers of different types.
	publishers map[reflect.Type]any
}

// NewInjector constructs an [Injector] that can be used to inject events into
// the stream of events used by code under test. After construction, the caller
// may use [Inject] to insert events into the bus. The injector's client is
// named after the test and is closed automatically when the test ends.
func NewInjector(t *testing.T, b *eventbus.Bus) *Injector {
	inj := &Injector{
		client:     b.Client(t.Name()),
		publishers: make(map[reflect.Type]any),
	}
	t.Cleanup(inj.client.Close)

	return inj
}

// Inject publishes event, of type T, onto the bus behind inj. The
// [eventbus.Publisher] for T is created on first use and reused afterwards.
// Inject is synchronous: the event has been published to the eventbus by the
// time the call returns.
func Inject[T any](inj *Injector, event T) {
	key := reflect.TypeFor[T]()

	var p *eventbus.Publisher[T]
	if existing, found := inj.publishers[key]; found {
		p = existing.(*eventbus.Publisher[T])
	} else {
		p = eventbus.Publish[T](inj.client)
		inj.publishers[key] = p
	}
	p.Publish(event)
}

// EqualTo builds an event-matching function, usable with [Expect] and
// [ExpectExactly], that accepts an event of type T only when it equals want
// according to [cmp.Diff]. On a mismatch, the returned error carries the diff.
func EqualTo[T any](want T) func(T) error {
	check := func(got T) error {
		diff := cmp.Diff(got, want)
		if diff == "" {
			return nil
		}
		return fmt.Errorf("wrong result (-got, +want):\n%s", diff)
	}
	return check
}

// LogAllEvents writes a numbered, one-line summary of every event routed via
// bus to t's log for the rest of the test. It is a development and debugging
// aid for writing tests. At cleanup, the bus watcher is closed and the
// background logger is waited for, so nothing is logged after the test ends.
func LogAllEvents(t testing.TB, bus *eventbus.Bus) {
	watcher := bus.Debugger().WatchBus()
	stopped := make(chan struct{})
	go func() {
		defer close(stopped)
		for seq := 1; ; seq++ {
			select {
			case re := <-watcher.Events():
				t.Logf("[eventbus] #%[1]d: %[2]T | %+[2]v", seq, re.Event)
			case <-watcher.Done():
				return
			}
		}
	}()
	t.Cleanup(func() {
		watcher.Close()
		<-stopped
	})
}
